# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Support to generate the latent variables for the squiggles dataset in Jax."""
import enum
import math

from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import jax
from jax import random
import jax.numpy as jnp
from shapely import geometry
import tensorflow as tf
import tensorflow_datasets as tfds

from squiggles import space_mappings


class LatentSpace(enum.Enum):
  """Latent parameterizations that can be used to generate squiggles."""
  UNDEFINED = enum.auto()
  TAYLOR = enum.auto()
  SINE_NET = enum.auto()

  @classmethod
  def all_spaces(cls):
    """Returns the names of all members, in definition order."""
    names = []
    for member in cls:
      names.append(member.name)
    return names


# Default hidden size to use for each latent space.
LATENT_SPACE_TO_DEFAULT_HIDDEN_SIZE: Dict[LatentSpace, int] = dict([
    (LatentSpace.SINE_NET, 25),  # original paper used 300
    (LatentSpace.TAYLOR, 7),
])


def _sine_net_latent_and_coords(
    random_key, hidden_size,
    num_points):
  """Generates one SineNet sample.

  Pseudorandomly draws a point in SineNet latent space and returns it together
  with its translation into a series of points on a curve, rescaled to fit
  inside a [0, 1] x [0, 1] square.

  Args:
    random_key: The JAX PRNG key used to draw the latent variables.
    hidden_size: The number of sinusoidal functions used in the generation.
    num_points: The number of points on the curve to compute.

  Returns:
    The latent representation [hidden_size x 4] and
    the coordinate representation [num_points x 2].
  """
  freq_key, phase_key, amp_key = random.split(random_key, 3)

  # Latent columns, in order: frequency ~ N(0, 1), phase shift drawn uniformly
  # from [0, 2*pi), then the x and y amplitudes ~ N(0, 1).
  latent = jnp.concatenate(
      [
          random.normal(freq_key, shape=(hidden_size, 1)),
          random.uniform(
              key=phase_key,
              shape=(hidden_size, 1),
              minval=0,
              maxval=2 * math.pi),
          random.normal(amp_key, shape=(hidden_size, 2)),
      ],
      axis=-1)

  # Add a leading axis of size one before mapping to a path; the rescaled
  # result is unpacked as a single-element sequence.
  batched_paths = space_mappings.sine_net_to_path_points(
      latent[jnp.newaxis, ...], num_points)
  (coords,) = space_mappings.rescale_points(batched_paths)
  return latent, coords


def _self_intersects(points):
  """Returns True when the polyline through `points` is not simple.

  A shapely LineString is not simple when it intersects itself.
  Assumes points has shape [num_points, 2] — TODO confirm at call sites.
  """
  polyline = geometry.LineString(points)
  is_simple = polyline.is_simple
  return not is_simple


def _taylor_latent_and_coords(
    random_key, hidden_size,
    num_points):
  """Generates one Taylor sample.

  Pseudorandomly draws a point in latent space and returns it together with
  its translation into a series of points on a curve, rescaled to fit inside
  a [0, 1] x [0, 1] square.

  Args:
    random_key: The JAX PRNG key used to draw the latent variables.
    hidden_size: The number of latent parameters per dimension (x/y).
    num_points: The number of points on the curve to compute.

  Returns:
    The latent representation [2 x hidden_size] and
    the coordinate representation [num_points x 2].
  """
  # One row of hidden_size coefficients per dimension, drawn uniformly from
  # [-1, 1).
  derivs = random.uniform(
      key=random_key, shape=(2, hidden_size), minval=-1.0, maxval=1.0)

  # Add a leading axis of size one before mapping to a path; the rescaled
  # result is unpacked as a single-element sequence.
  batched = derivs[jnp.newaxis, ...]
  (coords,) = space_mappings.rescale_points(
      space_mappings.derivs_to_path_points(batched, num_points))
  return derivs, coords


def generate_dataset(
    latent_space, start_seed, end_seed,
    dataset_code, hidden_size,
    num_points):
  """Generate samples for a Squiggles dataset.

  Each sample is drawn from its own PRNG key. For every seed in the half-open
  range [start_seed, end_seed), the key is built from the adjusted seed
  `(seed << 4) + dataset_code`. The dataset code occupies the lowest 4 bits
  of the adjusted seed. As a result, datasets generated with different codes
  never share an adjusted seed. Under the assumption that different seeds
  yield distinct random numbers, this enables non-overlapping train,
  validation and test splits.

  Args:
    latent_space: the latent space that needs to be used for generation. This is
      either TAYLOR or SINE_NET.
    start_seed: the first seed (inclusive) for which a sample is generated.
    end_seed: the seed (exclusive) at which generation stops.
    dataset_code: the dataset code is used to make sure samples do not overlap
      between train and test set. The permissible values are [0, 16).
    hidden_size: the size of the latent space (up to a constant factor).
    num_points: the number of coordinates to generate and use for the labelling.

  Returns:
    The latent representation as list of datapoints each with dimension
    [2 x hidden_size] for TAYLOR or [hidden_size x 4] for SINE_NET.
    The coordinate representation as a list of datapoints each with dimension
    [num_points x 2].
    The labels as a list of booleans (True when the curve self-intersects).

  Raises:
    ValueError: if `latent_space` is neither TAYLOR nor SINE_NET, or if
      `dataset_code` is outside [0, 16).
  """
  if latent_space == LatentSpace.TAYLOR:
    data_fn = _taylor_latent_and_coords
  elif latent_space == LatentSpace.SINE_NET:
    data_fn = _sine_net_latent_and_coords
  else:
    raise ValueError(f'Unrecognized LatentSpace: {latent_space}')

  # The lowest 4 bits of the adjusted seed encode the dataset id, which allows
  # 16 mutually distinct datasets (e.g. train, validation and test splits).
  if not 0 <= dataset_code < 16:
    raise ValueError(
        f'The dataset_code value must be between 0 (inclusive) and 16 (exclusive): {dataset_code}'
    )

  latents = []
  coordinates = []
  labels = []
  for seed in range(start_seed, end_seed):
    # The parentheses matter: `+` binds tighter than `<<`, so
    # `seed << 4 + dataset_code` would compute `seed << (4 + dataset_code)`.
    # That form makes seeds collide across dataset codes (2 << 4 == 1 << 5).
    adjusted_seed = (seed << 4) + dataset_code
    latent, coordinate = data_fn(
        random.PRNGKey(adjusted_seed), hidden_size, num_points)
    latents.append(latent)
    coordinates.append(coordinate)
    labels.append(_self_intersects(coordinate))
  return latents, coordinates, labels


def _float_list_feature(data):
  """Wraps an iterable of floats in a tf.train.Feature holding a FloatList."""
  float_list = tf.train.FloatList(value=data)
  return tf.train.Feature(float_list=float_list)


def _int_feature(value):
  """Wraps a single integer in a tf.train.Feature holding an Int64List."""
  int_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int_list)


def numpy_to_example(
    self_intersects,
    serial_no,
    *,  # Following arguments are keyword-only.
    taylor_params=None,  # keyword-only
    sinenet_params=None,  # keyword-only
    points=None,  # keyword-only
):
  """Structures curve data into a tf.Example.

  The example always carries the 'label' and 'serial_no' int features. Each
  array that is provided is flattened and stored as a float-list feature under
  'inputs/points', 'inputs/taylor' or 'inputs/sinenet'. The taylor parameters
  are transposed before flattening.

  Raises:
    ValueError: if none of taylor_params, sinenet_params and points is given,
      or if both taylor_params and sinenet_params are given.
  """
  if all(arr is None for arr in (taylor_params, sinenet_params, points)):
    raise ValueError(
        'At least one of taylor_params, sinenet_params, points should be '
        'populated.')
  if not (taylor_params is None or sinenet_params is None):
    raise ValueError(
        'At most one of taylor_params, sinenet_params can be populated.')

  feature_map = {
      'label': _int_feature(self_intersects),
      'serial_no': _int_feature(serial_no),
  }
  optional_inputs = (
      ('inputs/points', points),
      ('inputs/taylor',
       None if taylor_params is None else taylor_params.transpose()),
      ('inputs/sinenet', sinenet_params),
  )
  for key, array in optional_inputs:
    if array is not None:
      feature_map[key] = _float_list_feature(array.flatten())
  return tf.train.Example(features=tf.train.Features(feature=feature_map))


def write_metadata(base_dir,
                   samples_per_shard,
                   num_shards_by_split,
                   *,
                   num_points=100,
                   taylor_hidden_size=None,
                   sine_net_hidden_size=None):
  """Writes metadata files to make dataset easy to load.

  Declares a float32 'points' tensor of shape (num_points, 2). It also
  declares 'taylor' (taylor_hidden_size, 2) and 'sinenet'
  (sine_net_hidden_size, 4) tensors when their sizes are given. Together with
  a binary 'label' these are written via tfds.folder_dataset.write_metadata.
  Each split in num_shards_by_split is recorded with that many shards, each
  of samples_per_shard examples.

  NOTE(review): 'inputs' is declared as a tfds Sequence, while each tf.Example
  produced by numpy_to_example holds a single set of inputs — confirm this
  wrapping is intended.
  """
  inputs = {
      'points': tfds.features.Tensor(shape=(num_points, 2), dtype=tf.float32)
  }
  optional_tensors = (
      ('taylor', taylor_hidden_size, 2),
      ('sinenet', sine_net_hidden_size, 4),
  )
  for key, size, width in optional_tensors:
    if size is not None:
      inputs[key] = tfds.features.Tensor(
          shape=(size, width), dtype=tf.float32)

  features = tfds.features.FeaturesDict({
      'inputs': tfds.features.Sequence(inputs),
      'label': tfds.features.ClassLabel(names=['simple', 'intersecting']),
  })
  split_infos = [
      tfds.core.SplitInfo(
          name=split_name,
          shard_lengths=[samples_per_shard] * shard_count,
          num_bytes=0,
      )
      for split_name, shard_count in num_shards_by_split.items()
  ]
  tfds.folder_dataset.write_metadata(
      data_dir=base_dir,
      features=features,
      split_infos=split_infos,
  )


def read_dataset(base_dir, split):
  """Loads `split` of the TFDS dataset stored in directory `base_dir`."""
  return tfds.core.read_only_builder.builder_from_directory(
      base_dir).as_dataset(split=split)


def write_to_tfrecord(
    base_file,
    split,
    points,
    labels,
    *,
    shard_num,
    num_shards,
    taylor_latents=None,
    sinenet_latents=None,
):
  """Writes curve data to a tfrecord shard.

  The shard is written to
  '{base_file}-{split}.tfrecord-{shard_num:05}-of-{num_shards:05}'. Each record
  is built with numpy_to_example. Its serial number is the record's index
  within this shard. Records are produced by zipping the inputs, so writing
  stops at the shortest of points, labels and the (optional) latent lists.
  """
  shard_path = (
      f'{base_file}-{split}.tfrecord-{shard_num:05}-of-{num_shards:05}')
  record_count = len(labels)
  if taylor_latents is None:
    taylor_latents = [None] * record_count
  if sinenet_latents is None:
    sinenet_latents = [None] * record_count
  with tf.io.TFRecordWriter(shard_path) as writer:
    rows = zip(taylor_latents, sinenet_latents, points, labels)
    for serial_no, (taylor, sinenet, pointlist, label) in enumerate(rows):
      example = numpy_to_example(
          label,
          serial_no,
          taylor_params=taylor,
          sinenet_params=sinenet,
          points=pointlist)
      writer.write(example.SerializeToString())
